#ifndef LOSS_GENERATOR_H
#define LOSS_GENERATOR_H

#include <memory>

#include "NN_data.h"
#include "App1_variable.h"
#include <gurobi_c++.h>
/// Builds and solves per-sample Gurobi quadratic programs over `n_label`
/// nonnegative decision variables and writes the resulting losses/solutions
/// into an `app1_solution`.
///
/// The object keeps raw pointers to its inputs/outputs and never deletes them
/// (no destructor is declared), so the caller must keep them alive for as long
/// as this object is used.
class Loss_generator {
	protected:
		problem pr;                            // copy of the problem coefficients (k, p, A, b)
		NN_data* nn_data;                      // provides batch_ans / batch_expected (not owned)
		const User_parameter* user_parameter;  // provides beta, th, lambda (not owned)
		struct app1_indata* indata;            // stored; not read by any method in this file (not owned)
		struct app1_solution* solution;        // destination for losses and solutions (not owned)
		GRBEnv env;                            // Gurobi environment every model in this class is built on
	public:
		/// Copies `pr` and stores the remaining (non-owned) pointers.
		Loss_generator(problem pr, NN_data* nn_data, User_parameter* user_parameter,
		               struct app1_indata* indata, struct app1_solution* solution);

		/// Optimal QP objective using the realized demand (`batch_ans`); stored in
		/// solution->ulti_best_loss[sample_order] and returned.
		double ulti_best_loss_batch(int sample_order, normalclass nc, posterior p);

		/// Like ulti_best_loss_batch plus a beta-weighted penalty on labels whose
		/// realized `batch_ans` is below 3; stored in solution->best_loss.
		double best_loss_batch(int sample_order, normalclass nc, posterior p);

		/// Solves the QP with `batch_expected` demand and records the optimal x in
		/// solution->mainsol_batch[sample_order][i][0], updating solution->vio.
		void get_solution(int sample_order, normalclass nc, posterior p);

		/// Two-stage approximate loss for one sample; when `is_main` is nonzero
		/// also records both stage solutions and solution->loss[sample_order].
		double calc_approx_loss_batch(int sample_order, normalclass nc, posterior p, int is_main = 0);
};

/// Stores the problem data and the caller-owned input/output pointers.
///
/// `inline` is required because this definition lives in a header: without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// @param pr             problem coefficients; copied into the object
/// @param nn_data        demand data source (not owned)
/// @param user_parameter tuning parameters (not owned; only read)
/// @param indata         stored for subclasses/callers (not owned)
/// @param solution       output sink for losses and solutions (not owned)
inline Loss_generator::Loss_generator(problem pr, NN_data* nn_data, User_parameter* user_parameter, struct app1_indata* indata, struct app1_solution* solution)
	: pr(pr),
	  nn_data(nn_data),
	  user_parameter(user_parameter),
	  indata(indata),
	  solution(solution)
{
	// `env` is default-constructed, which starts a Gurobi environment.
}
double Loss_generator::ulti_best_loss_batch(int sample_order, normalclass nc, posterior p) { 	GRBModel model = GRBModel(env);
	GRBVar* x = model.addVars(n_label);
	GRBVar* sb = model.addVars(5);
	GRBVar* qb = model.addVars(5);
	GRBQuadExpr obj;
	for (int i = 0;i < n_label;i++)
	{
		x[i].set(GRB_DoubleAttr_LB, 0);
		obj += pr.k[i] * x[i] * x[i] - (pr.p[i] + pr.k[i] * this->nn_data->batch_ans[i][sample_order]) * x[i];
	}

	for (int j = 0;j < 5;j++)
	{
		GRBLinExpr tmp;
		GRBQuadExpr tmp2;
		for (int i = 0;i < n_label;i++)
		{
			tmp += pr.A[i][j] * x[i];
			tmp2 += x[i] * x[i];
		}
		sb[j].set(GRB_DoubleAttr_LB, 0);
		qb[j].set(GRB_DoubleAttr_LB, 0);
		model.addConstr(sb[j] >= tmp - pr.b[j] + qb[j]);
		model.addConstr(sb[j] >= tmp - pr.b[j]);
		model.addQConstr(qb[j] * qb[j] >= tmp2);
		obj += this->user_parameter->beta * sb[j];
	}
	model.setObjective(obj);
	model.set(GRB_IntParam_LogToConsole, 0);
	model.optimize();
	double loss = model.get(GRB_DoubleAttr_ObjVal);
	this->solution->ulti_best_loss[sample_order] = loss;

	return loss;
}
double Loss_generator::best_loss_batch(int sample_order, normalclass nc, posterior p) { 	GRBModel model = GRBModel(env);
	GRBVar* x = model.addVars(n_label);
	GRBVar* sx = model.addVars(n_label);
	GRBVar* sb = model.addVars(5);
	GRBVar* qb = model.addVars(5);
	GRBQuadExpr obj;
	for (int i = 0;i < n_label;i++)
	{
		double tmpth = p.postprob[i][0][1] < p.postprob[i][1][1] ? p.postprob[i][0][1] : p.postprob[i][1][1];
		x[i].set(GRB_DoubleAttr_LB, 0);
		obj += pr.k[i] * x[i] * x[i] - (pr.p[i] + pr.k[i] * this->nn_data->batch_ans[i][sample_order]) * x[i];
		sx[i].set(GRB_DoubleAttr_LB, 0);
		if (this->nn_data->batch_ans[i][sample_order]< 3)
			model.addConstr(sx[i] >= x[i]);
		obj += this->user_parameter->beta * sx[i];
	}
	
	for (int j = 0;j < 5;j++)
	{
		GRBLinExpr tmp;
		GRBQuadExpr tmp2;
		for (int i = 0;i < n_label;i++)
		{
			tmp += pr.A[i][j] * x[i];
			tmp2 += x[i] * x[i];
		}
		sb[j].set(GRB_DoubleAttr_LB, 0);
		qb[j].set(GRB_DoubleAttr_LB, 0);
		model.addConstr(sb[j] >= tmp - pr.b[j] + qb[j]);
		model.addConstr(sb[j] >= tmp - pr.b[j]);
		model.addQConstr(qb[j] * qb[j] >= tmp2);
		obj += this->user_parameter->beta * sb[j];
	}
	model.setObjective(obj);
	model.set(GRB_IntParam_LogToConsole, 0);
	model.optimize();
	double loss = model.get(GRB_DoubleAttr_ObjVal);
	this->solution->best_loss[sample_order] = loss;

	return loss;
}

void Loss_generator::get_solution(int sample_order, normalclass nc, posterior p) {
	double loss = 0;
	GRBModel model = GRBModel(env);
	GRBVar* x = model.addVars(n_label);
	GRBVar* sx = model.addVars(n_label);
	GRBVar* sb = model.addVars(5);
	GRBVar* qb = model.addVars(5);
	GRBQuadExpr obj;
	for (int i = 0;i < n_label;i++)
	{
		x[i].set(GRB_DoubleAttr_LB, 0);
		obj += pr.k[i] * x[i] * x[i] - (pr.p[i] + pr.k[i] * this->nn_data->batch_expected[i][sample_order]) * x[i];
		sx[i].set(GRB_DoubleAttr_LB, 0);
		if (p.postprob[i][nc.nclass[i][sample_order]][1] > this->user_parameter->th)
			model.addConstr(sx[i] >= x[i]);
		obj += this->user_parameter->beta * sx[i];
	}
	for (int j = 0;j < 5;j++)
	{
		GRBLinExpr tmp;
		GRBQuadExpr tmp2;
		for (int i = 0;i < n_label;i++)
		{
			tmp += pr.A[i][j] * x[i];
			tmp2 += x[i] * x[i];
		}
		sb[j].set(GRB_DoubleAttr_LB, 0);
		qb[j].set(GRB_DoubleAttr_LB, 0);
		model.addConstr(sb[j] >= tmp - pr.b[j] + qb[j]);
		model.addConstr(sb[j] >= tmp - pr.b[j]);
		model.addQConstr(qb[j] * qb[j] >= tmp2);
		obj += this->user_parameter->beta * sb[j];
	}
	model.setObjective(obj);
	model.set(GRB_IntParam_LogToConsole, 0);
	model.optimize();
	for (int i = 0;i < n_label;i++)
	{
		this->solution->mainsol_batch[sample_order][i][0] = x[i].get(GRB_DoubleAttr_X);
		this->solution->vio[i] += (this->solution->mainsol_batch[sample_order][i][0] > 1e-9) * (this->nn_data->batch_ans[i][sample_order] < 3);
	}
}



double Loss_generator::calc_approx_loss_batch(int sample_order, normalclass nc, posterior p, int is_main) { 	double loss = 0, loss_post = 0;
	for (int l = 0;l < n_label;l++)
	{
		if (is_main &&sample_order==0)
			printf("%lf\t%lf\n", p.postprob[l][0][1], p.postprob[l][1][0]);
		if (p.postprob[l][0][1] > this->user_parameter->th)
			loss_post += this->user_parameter->beta * log(p.postprob[l][0][1] / this->user_parameter->th);
	}
	GRBModel model = GRBModel(env);
	GRBVar* x = model.addVars(n_label);
	GRBVar* sx = model.addVars(n_label);
	GRBVar* sb = model.addVars(5);
	GRBVar* qb = model.addVars(5);
	GRBQuadExpr obj;
	for (int i = 0;i < n_label;i++)
	{
		double tmpth = p.postprob[i][0][1] < p.postprob[i][1][1] ? p.postprob[i][0][1] : p.postprob[i][1][1];
		x[i].set(GRB_DoubleAttr_LB, 0);
		obj += pr.k[i] * x[i] * x[i] - (pr.p[i] + pr.k[i] * this->nn_data->batch_expected[i][sample_order]) * x[i];
		sx[i].set(GRB_DoubleAttr_LB, 0);
		if (p.postprob[i][nc.nclass[i][sample_order]][1] > tmpth && p.postprob[i][nc.nclass[i][sample_order]][1] > this->user_parameter->th)
			model.addConstr(sx[i] >= x[i]);
		obj += this->user_parameter->beta * sx[i];
	}
	for (int j = 0;j < 5;j++)
	{
		GRBLinExpr tmp;
		GRBQuadExpr tmp2;
		for (int i = 0;i < n_label;i++)
		{
			tmp += pr.A[i][j] * x[i];
			tmp2 += x[i] * x[i];
		}
		sb[j].set(GRB_DoubleAttr_LB, 0);
		qb[j].set(GRB_DoubleAttr_LB, 0);
		model.addConstr(sb[j] >= tmp-pr.b[j]+qb[j]);
		model.addConstr(sb[j] >= tmp - pr.b[j]);
		model.addQConstr(qb[j] * qb[j] >= tmp2);
		obj += this->user_parameter->beta * sb[j];
	}
	model.setObjective(obj);
	model.set(GRB_IntParam_LogToConsole, 0);
	model.optimize();
	if(is_main)
		for (int i = 0;i < n_label;i++)
		{
			this->solution->mainsol_batch[sample_order][i][0] = x[i].get(GRB_DoubleAttr_X);
			this->solution->vio[i] += (this->solution->mainsol_batch[sample_order][i][0] > 1e-9) * (this->nn_data->batch_ans[i][sample_order] < 3);
		}
	
	loss -= model.get(GRB_DoubleAttr_ObjVal) / this->user_parameter->lambda;

	for (int i = 0;i < n_label;i++)
		obj += this->user_parameter->lambda * (pr.k[i] * x[i] * x[i] - (pr.p[i] + pr.k[i] * this->nn_data->batch_ans[i][sample_order]) * x[i]);
	model.setObjective(obj);
	model.optimize();
	
	loss += model.get(GRB_DoubleAttr_ObjVal) / this->user_parameter->lambda;
	if (is_main)
		for (int i = 0;i < n_label;i++)
			this->solution->mainsol_batch[sample_order][i][1] = x[i].get(GRB_DoubleAttr_X);
	loss += loss_post;
	if (is_main)
		this->solution->loss[sample_order] = loss;
	return loss;
}

#endif